{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "68397145-e13d-4e6d-bb0a-35944409c333",
   "metadata": {},
   "source": [
    "# Summarization Scoring using AWS Bedrock\n",
    "\n",
    "In this example notebook, we will be comparing how different models in AWS Bedrock perform at summarization tasks using **Arthur Bench**. \n",
    "\n",
    "The overall summarization comparison is setup as a **Bench TestSuite**, and each model is compared head-to-head against every other model\n",
    "in a **Bench TestRun**. Bench provides the ability to view these comparisons in the provided User Interface, as well as access statistics\n",
    "from the Test and TestRuns themselves for further analysis. For example, in the notebook below we provide a means for using the ELO Scoring\n",
    "Algorithm to determine which model performs best at this summarization task. \n",
    "                                                                                                                 \n",
    "The task is to summarize 49 News Articles, and comparison is done using ChatGPT 3.5 Turbo (see summary_quality.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c97dda90-22a7-4a8d-8cf5-e2948d35f8af",
   "metadata": {},
   "source": [
    "## 1. Setup AWS Bedrock Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d375f148-b8f3-4a6d-8a10-f4e96e97ce90",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Authentication is handled using the AWS_PROFILE environment variable. Check the AWS Boto3 documentation and the provided\n",
    "utility library for connecting to Bedrock for additional information\n",
    "\"\"\"\n",
    "from bedrock_client import client\n",
    "\n",
    "bedrock_runtime = client.get_bedrock_client(region=\"us-east-1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b3eeae-afa0-474b-974b-bcaa716d83c8",
   "metadata": {},
   "source": [
    "## 2. Load the data and prepare it for inferencing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfee14bb-017d-4057-a406-1f0cc60e5bee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "articles = []\n",
    "\n",
    "with open(\"data/news_summary/example_summaries.csv\", \"r\") as f:\n",
    "    dr = csv.DictReader(f)\n",
    "    for row in dr: \n",
    "        articles.append(row[\"input_text\"])\n",
    "\n",
    "len(articles)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff9d4a64-1fc0-493c-b7d5-556d4174313e",
   "metadata": {},
   "source": [
    "## 3. Generate inferences (summaries) for the articles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6d18def-3fe4-4eba-aacc-0882cf123bd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "prompt = \"\"\"\\\n",
    "Summarize the following news document down to its most important points in less than 250 words.\n",
    "{}\n",
    "\"\"\"\n",
    "\n",
    "def generate_summary_from_llama(model_id, article): \n",
    "    body = json.dumps({\"prompt\": prompt.format(article)})\n",
    "    modelId = model_id\n",
    "    accept = \"application/json\"\n",
    "    contentType = \"application/json\"\n",
    "    \n",
    "    response = bedrock_runtime.invoke_model(\n",
    "        body=body, modelId=modelId, accept=accept, contentType=contentType\n",
    "    )\n",
    "    response_body = json.loads(response.get(\"body\").read())\n",
    "    return response_body.get(\"generation\")\n",
    "\n",
    "\n",
    "def generate_summary_from_titan(model_id, article):\n",
    "    body = json.dumps({\"inputText\": prompt.format(article)})\n",
    "    modelId = model_id\n",
    "    accept = \"application/json\"\n",
    "    contentType = \"application/json\"\n",
    "    \n",
    "    response = bedrock_runtime.invoke_model(\n",
    "        body=body, modelId=modelId, accept=accept, contentType=contentType\n",
    "    )\n",
    "    response_body = json.loads(response.get(\"body\").read())\n",
    "    return response_body.get(\"results\")[0].get(\"outputText\")\n",
    "\n",
    "\n",
    "def generate_summary_from_claude(model_id, article):\n",
    "    body = json.dumps({\n",
    "        \"prompt\": f\"Human:\\n{prompt.format(article)}\\nAssistant:\\n\",\n",
    "        \"max_tokens_to_sample\": 4000,  # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html\n",
    "    })\n",
    "    modelId = model_id\n",
    "    accept = \"application/json\"\n",
    "    contentType = \"application/json\"\n",
    "    \n",
    "    response = bedrock_runtime.invoke_model(\n",
    "        body=body, modelId=modelId, accept=accept, contentType=contentType\n",
    "    )\n",
    "    response_body = json.loads(response.get(\"body\").read())\n",
    "    return response_body.get(\"completion\")    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f76d25c-7241-48db-9c82-e00335d6da29",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = (\n",
    "    \"meta.llama2-13b-chat-v1\",\n",
    "    \"meta.llama2-70b-chat-v1\",\n",
    "    \"amazon.titan-text-lite-v1\",\n",
    "    \"anthropic.claude-v2\",\n",
    "    \"anthropic.claude-v1\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea41fa3a-13c8-4c07-89f5-6b5e2e58edcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "summaries = defaultdict(list)\n",
    "summaries[\"input\"] = articles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b94c7974-410e-4c7c-b213-9555cc8c0871",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in [m for m in models if m not in summaries.keys()]:\n",
    "    summaries.setdefault(model, [])\n",
    "    for i, article in enumerate(articles):         \n",
    "        try:\n",
    "            print(f\"Generating summary for article {i} using model {model}\")\n",
    "            if \"meta\" in model:\n",
    "                summary = generate_summary_from_llama(model, article)\n",
    "            elif \"amazon\" in model:\n",
    "                summary = generate_summary_from_titan(model, article)\n",
    "            elif \"claude\" in model:\n",
    "                summary = generate_summary_from_claude(model, article)\n",
    "            else:\n",
    "                print(f\"Unable to determine what {model} is\")\n",
    "                continue\n",
    "                \n",
    "            summaries[model].append(summary)\n",
    "        except:\n",
    "            print(f\"Couldn't generate summary for article {i} using model {model}\")\n",
    "            summaries[model].append(\"Unable to summarize\")\n",
    "            continue"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7649fcb0-7e11-4002-9bed-e9c2185112a1",
   "metadata": {},
   "source": [
    "### 3.5 Save (and load) inferences to a pickle file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aad5571-e159-4559-883c-06eeb9f81f6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"summaries.pkl\", \"wb\") as f:\n",
    "    pickle.dump(summaries, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32024ed-aa38-42bf-81d7-805dc7d68c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('summaries.pkl', 'rb') as f:\n",
    "    summaries = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54947f69-7b97-467f-bf09-86eac0c8198d",
   "metadata": {},
   "source": [
    "## 4. Setup Bench TestSuites for each model and run Bench TestRuns\n",
    "\n",
    "The below sets up TestSuites + Runs for each unique combination so we have the ability to rank the models in a round-robin tournament for ELO rating. \n",
    "\n",
    "See the [Quickstart Guide](https://bench.readthedocs.io/en/latest/quickstart.html#view-results-in-local-ui) for additional information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e53792-7c88-41d4-a778-29aeab195f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from arthur_bench.run.testsuite import TestSuite\n",
    "from itertools import combinations\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "# Make sure you've set the BENCH_FILE_DIR environment variable\n",
    "\n",
    "# The summary_quality scorer uses gpt-3.5-turbo to score the summary\n",
    "# so make sure that your OPENAI_API_KEY environment variable is set\n",
    "\n",
    "combos = list(combinations(models, 2))\n",
    "d_combos = defaultdict(list)\n",
    "for m1, m2 in combos:\n",
    "    d_combos[m1].append(m2)\n",
    "\n",
    "print(d_combos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec69abd6-5aa6-4a40-9f79-00f9505fcb59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_test_suite(model):\n",
    "    print(f\"Creating test suite for model {model}\")\n",
    "    return TestSuite(\n",
    "        f\"News Summarization using {model} as reference\", \n",
    "        'summary_quality',\n",
    "        input_text_list=summaries[\"input\"],\n",
    "        reference_output_list=summaries[model]\n",
    "    )    \n",
    "\n",
    "suites = {}\n",
    "for model in d_combos.keys():\n",
    "    suites[model] = create_test_suite(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09afb482-cbb7-43fe-8f3b-804bd0691d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_test_suite(suite, ref_model, cand_model):\n",
    "    run = suite.run(\n",
    "        run_name=f\"{ref_model}_vs_{cand_model}\",\n",
    "        candidate_output_list=summaries[cand_model]\n",
    "    )\n",
    "    return run\n",
    "\n",
    "runs = defaultdict(dict)\n",
    "for ref_model, suite in suites.items():\n",
    "    for cand_model in d_combos[ref_model]:\n",
    "        runs[ref_model][cand_model] = run_test_suite(suite, ref_model, cand_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1f1185e4-b39d-4107-bbc4-3aedc449cf46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model anthropic.claude-v2: Rating 1345.4153155340496\n",
      "Model meta.llama2-70b-chat-v1: Rating 1079.6979024278733\n",
      "Model amazon.titan-text-lite-v1: Rating 898.4211374051577\n",
      "Model anthropic.claude-v1: Rating 864.5103513782595\n",
      "Model meta.llama2-13b-chat-v1: Rating 811.9552932546615\n"
     ]
    }
   ],
   "source": [
    "from elo import Implementation\n",
    "\n",
    "i = Implementation()\n",
    "for model in models:\n",
    "    i.addPlayer(model)\n",
    "\n",
    "for ref_model in runs.keys():\n",
    "    for cand_model, run in runs[ref_model].items():\n",
    "        for test_case in run.test_cases:\n",
    "            score = test_case.score\n",
    "            if score == 1.0: \n",
    "                i.recordMatch(ref_model, cand_model, winner=cand_model)\n",
    "            if score == 0.5: \n",
    "                i.recordMatch(ref_model, cand_model, draw=True)\n",
    "            if score == 0.0:\n",
    "                i.recordMatch(ref_model, cand_model, winner=ref_model)\n",
    "            else:\n",
    "                pass\n",
    "\n",
    "sorted_list = sorted(i.getRatingList(), key=lambda x: x[1], reverse=True)\n",
    "\n",
    "# Output the sorted list\n",
    "for model, score in sorted_list:\n",
    "    print(f\"Model {model}: Rating {score}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
